import torch
from torch import nn


class ContextBlock(nn.Module):
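    """GCNet-style global context block with separate query/support inputs.

    The spatial attention mask is computed from ``support`` and used to pool
    the features of ``x`` into a per-channel context vector [N, C, 1, 1],
    which is then fused back into ``x`` via the ``channel_add`` and/or
    ``channel_mul`` branches.
    """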
    def __init__(self, inplanes=256, ratio=1.0/8, pooling_type='att', fusion_types=('channel_add',)):
        super(ContextBlock, self).__init__()
        valid_fusion_types = ['channel_add', 'channel_mul']

        assert pooling_type in ['avg', 'att']
        assert isinstance(fusion_types, (list, tuple))
        assert all([f in valid_fusion_types for f in fusion_types])
        assert len(fusion_types) > 0, 'at least one fusion should be used'

        self.inplanes = inplanes
        self.ratio = ratio
        self.planes = int(inplanes * ratio)
        self.pooling_type = pooling_type
        self.fusion_types = fusion_types

        if pooling_type == 'att':
            self.conv_mask = nn.Conv2d(inplanes, 1, kernel_size=1)
            self.softmax = nn.Softmax(dim=2)
        else:
            self.avg_pool = nn.AdaptiveAvgPool2d(1)
        if 'channel_add' in fusion_types:
            self.channel_add_conv = nn.Sequential(
                nn.Conv2d(self.inplanes, self.planes, kernel_size=1),
                nn.LayerNorm([self.planes, 1, 1]),
                nn.ReLU(inplace=True),  # yapf: disable
                nn.Conv2d(self.planes, self.inplanes, kernel_size=1))
        else:
            self.channel_add_conv = None
        if 'channel_mul' in fusion_types:
            self.channel_mul_conv = nn.Sequential(
                nn.Conv2d(self.inplanes, self.planes, kernel_size=1),
                nn.LayerNorm([self.planes, 1, 1]),
                nn.ReLU(inplace=True),  # yapf: disable
                nn.Conv2d(self.planes, self.inplanes, kernel_size=1))
        else:
            self.channel_mul_conv = None

    def spatial_pool(self, x, support):
        # x: query features, e.g. [1, 256, 16, 16]
        # support: support features, e.g. [1, 256, 16, 16]
        batch, channel, height, width = x.size()
        if self.pooling_type == 'att':
            input_x = x
            # x [N, C, H, W]
            input_x = input_x.view(batch, channel, height * width)
            # input_x [N, C, H * W]

            input_x = input_x.unsqueeze(1)
            # input_x [N, 1, C, H * W]

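            # Unlike the single-input OldContextBlock below, the attention
            # mask is computed from the support features rather than from x.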
            context_mask = self.conv_mask(support)
            # context_mask [N, 1, H, W]

            context_mask = context_mask.view(batch, 1, height * width)
            # context_mask [N, 1, H * W]

            context_mask = self.softmax(context_mask)
            # context_mask [N, 1, H * W]

            context_mask = context_mask.unsqueeze(-1)
            # context_mask [N, 1, H * W, 1]

            context = torch.matmul(input_x, context_mask)
            # input_x [N, 1, C, H * W], context_mask [N, 1, H * W, 1]
            # context [N, 1, C, 1]

            context = context.view(batch, channel, 1, 1)
            # [N, C, 1, 1]
        else:
            context = self.avg_pool(x)
            # [N, C, 1, 1]
        return context

    def forward(self, x, support):
        context = self.spatial_pool(x, support)
        # [N, C, 1, 1]
        out = x
        # x [N, C, H, W]
        if self.channel_mul_conv is not None:
            channel_mul_term = torch.sigmoid(self.channel_mul_conv(context))
            # [N, C, 1, 1]
            out = out * channel_mul_term
        if self.channel_add_conv is not None:
            channel_add_term = self.channel_add_conv(context)
            # [N, C, 1, 1]
            out = out + channel_add_term
            # out [N, C, H, W]
        return out


class OldContextBlock(nn.Module):
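    """Single-input variant: the spatial attention mask is computed from ``x``
    itself, following the original GCNet ContextBlock design."""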
    def __init__(self, inplanes, ratio, pooling_type='att', fusion_types=('channel_add',)):
        super(OldContextBlock, self).__init__()
        valid_fusion_types = ['channel_add', 'channel_mul']

        assert pooling_type in ['avg', 'att']
        assert isinstance(fusion_types, (list, tuple))
        assert all([f in valid_fusion_types for f in fusion_types])
        assert len(fusion_types) > 0, 'at least one fusion should be used'

        self.inplanes = inplanes
        self.ratio = ratio
        self.planes = int(inplanes * ratio)
        self.pooling_type = pooling_type
        self.fusion_types = fusion_types

        if pooling_type == 'att':
            self.conv_mask = nn.Conv2d(inplanes, 1, kernel_size=1)
            self.softmax = nn.Softmax(dim=2)
        else:
            self.avg_pool = nn.AdaptiveAvgPool2d(1)
        if 'channel_add' in fusion_types:
            self.channel_add_conv = nn.Sequential(
                nn.Conv2d(self.inplanes, self.planes, kernel_size=1),
                nn.LayerNorm([self.planes, 1, 1]),
                nn.ReLU(inplace=True),  # yapf: disable
                nn.Conv2d(self.planes, self.inplanes, kernel_size=1))
        else:
            self.channel_add_conv = None
        if 'channel_mul' in fusion_types:
            self.channel_mul_conv = nn.Sequential(
                nn.Conv2d(self.inplanes, self.planes, kernel_size=1),
                nn.LayerNorm([self.planes, 1, 1]),
                nn.ReLU(inplace=True),  # yapf: disable
                nn.Conv2d(self.planes, self.inplanes, kernel_size=1))
        else:
            self.channel_mul_conv = None

    def spatial_pool(self, x):
        batch, channel, height, width = x.size()
        if self.pooling_type == 'att':
            input_x = x
            # x [N, C, H, W]
            input_x = input_x.view(batch, channel, height * width)
            # input_x [N, C, H * W]

            input_x = input_x.unsqueeze(1)
            # input_x [N, 1, C, H * W]

            context_mask = self.conv_mask(x)
            # context_mask [N, 1, H, W]

            context_mask = context_mask.view(batch, 1, height * width)
            # context_mask [N, 1, H * W]

            context_mask = self.softmax(context_mask)
            # context_mask [N, 1, H * W]

            context_mask = context_mask.unsqueeze(-1)
            # context_mask [N, 1, H * W, 1]

            context = torch.matmul(input_x, context_mask)
            # input_x [N, 1, C, H * W], context_mask [N, 1, H * W, 1]
            # context [N, 1, C, 1]

            context = context.view(batch, channel, 1, 1)
            # [N, C, 1, 1]
        else:
            context = self.avg_pool(x)
            # [N, C, 1, 1]
        return context

    def forward(self, x):
        context = self.spatial_pool(x)
        # [N, C, 1, 1]
        out = x
        # x [N, C, H, W]
        if self.channel_mul_conv is not None:
            channel_mul_term = torch.sigmoid(self.channel_mul_conv(context))
            # [N, C, 1, 1]
            out = out * channel_mul_term
        if self.channel_add_conv is not None:
            channel_add_term = self.channel_add_conv(context)
            # [N, C, 1, 1]
            out = out + channel_add_term
        return out


if __name__ == "__main__":
    # in_tensor = torch.ones((12, 64, 128, 128))
    # cb = OldContextBlock(inplanes=64, ratio=1. / 16., pooling_type='att')
    # out_tensor = cb(in_tensor)

    x = torch.ones((1, 256, 16, 16))
    s = torch.ones((1, 256, 16, 16))

    cb = ContextBlock(256, 1.0 / 4, pooling_type='att')
    out = cb(x, s)
    print('out', out.shape)
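
    # Quick sanity check of the single-input variant on the same tensor
    # (mirrors the commented-out example above, but with inplanes=256).
    old_cb = OldContextBlock(inplanes=256, ratio=1.0 / 4, pooling_type='att')
    print('old out', old_cb(x).shape)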
